data <- read.csv("https://raw.githubusercontent.com/sameralzaim/Project-3/main/bodyPerformance.csv",
header = TRUE, sep = ",", stringsAsFactors = FALSE)
#head (data, n=10)
This data was collected by the Korean Sports Promotion Foundation. It consists of measurement data for each physical fitness measurement item. The data provides itemized measurement information of the National Physical Fitness Measurement Data managed by the Seoul Olympic Commemorative National Sports Promotion Corporation.
The analysis has two parts:
+ Predicting the number of sit-ups an athlete can perform from the other body and performance measurements.
+ Drawing inference about the relationship between body and performance measurements and the probability that an athlete belongs to the top-performing class (“A”). For this, we regroup the “class” variable into two classes: “A” and “Non A”, which combines “B”, “C” and “D”.
The data consist of 13393 observations and 12 variables.
| Variable | Description | Class |
|---|---|---|
| age | Age between 20–64 | num |
| gender | F, M | chr |
| height_cm | Height in centimeters | num |
| weight_kg | Weight in kilograms | num |
| body.fat_. | Body fat percentage | num |
| diastolic | Blood pressure (diastolic) | num |
| systolic | Blood pressure (systolic) | num |
| gripForce | Measured in kg | num |
| sit.and.bend.forward_cm | Sit-and-bend-forward distance in centimeters | num |
| sit.ups.counts | Number of sit-ups in 2 minutes | num |
| broad.jump_cm | Broad jump distance in centimeters | num |
| class | Fitness level (A: best, D: lowest) | chr |
The data set contains no missing values:
[1] 0
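A per-column count is a slightly more informative version of this check than a single total. The sketch below uses a small hypothetical data frame (`toy`), not the project data; with the real data one would pass `data` instead.

```r
# Hypothetical toy data frame, used only for illustration
toy <- data.frame(age = c(25, NA, 41),
                  gripForce = c(40.2, 36.5, NA),
                  gender = c("M", "F", "F"))

# Number of missing values in each column
na.by.col <- colSums(is.na(toy))

# Total number of missing values across the whole data frame
na.total <- sum(is.na(toy))
```

A total of zero, as reported above, means every per-column count is zero as well.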
A high-level look at the distributions does not raise any concerns about the data. We examine each variable in detail later on; for now, the plots below show that male participants are almost double the number of female participants. The largest age group is 20-30, which is expected since this is athletic performance data. Sit-ups, broad jump and grip force are approximately normally distributed, while sit-and-bend values are concentrated between 40 and 50.
layout(matrix(1:12, nrow = 3), widths = c(1, 1, 1, 1))
# Create a frequency table Gender
gender_counts <- table(data$gender)
# Make the barplot and capture the midpoints of bars
bp <- barplot(gender_counts,
main = "Distribution by Gender",
col = "skyblue",
xlab = "Gender",
ylim = c(0, max(gender_counts) * 1.4)) # Add a little space for the labels
# Add the text labels (counts)
text(x = bp, y = gender_counts, labels = gender_counts, pos = 3, cex = 0.9)
# Create a frequency table
class_counts <- table(data$class)
# Make the barplot and capture the midpoints of bars
bp <- barplot(class_counts,
main = "Distribution by Class",
col = "skyblue",
xlab = "Class",
ylim = c(0, max(class_counts) * 1.4)) # Add a little space for the labels
# Add the text labels (counts)
text(x = bp, y = class_counts, labels = class_counts, pos = 3, cex = 0.9)
hist(data$body.fat_., main="Body Fat", col="skyblue", xlab="Value", breaks=6)
hist(data$weight_kg, main="Weight KG", col="skyblue", xlab="Value", breaks=6)
hist(data$height_cm, main="Height CM", col="skyblue", xlab="Value", breaks=6)
hist(data$diastolic, main="Diastolic", col="skyblue", xlab="Value", breaks=6)
hist(data$systolic, main="Systolic", col="skyblue", xlab="Value", breaks=6)
hist(data$gripForce, main="gripForce", col="skyblue", xlab="Value", breaks=6)
hist(data$sit.and.bend.forward_cm, main="sit.and.bend.forward_cm", col="skyblue", xlab="Value", breaks=6)
hist(data$sit.ups.counts, main="sit.ups.counts", col="skyblue", xlab="Value", breaks=6)
hist(data$broad.jump_cm, main="broad.jump_cm", col="skyblue", xlab="Value", breaks=6)
hist(data$age, main="age", col="skyblue", xlab="Value", breaks=6)
🟦 Gender: The sample skews male, with roughly two-thirds male and one-third female participants.
🟦 Class: The data are evenly distributed across the four classes.
🟦 Age: The data skew toward younger participants, with almost 45% in the 20-30 age group.
# Set up layout: one row, three columns of equal width
layout(matrix(1:3, nrow = 1), widths = c(1, 1, 1))
# Create a frequency table Gender
gender_counts <- table(data$gender)
# Make the barplot and capture the midpoints of bars
bp <- barplot(gender_counts,
main = "Distribution by Gender",
col = "skyblue",
xlab = "Gender",
ylim = c(0, max(gender_counts) * 1.1)) # Add a little space for the labels
# Add the text labels (counts)
text(x = bp, y = gender_counts, labels = gender_counts, pos = 3, cex = 0.9)
# Create a frequency table
class_counts <- table(data$class)
# Make the barplot and capture the midpoints of bars
bp <- barplot(class_counts,
main = "Distribution by Class",
col = "skyblue",
xlab = "Class",
ylim = c(0, max(class_counts) * 1.1)) # Add a little space for the labels
# Add the text labels (counts)
text(x = bp, y = class_counts, labels = class_counts, pos = 3, cex = 0.9)
# Create a frequency table
age_counts <- table(data$age)
hist(data$age, main="Distribution by Age", col="skyblue", xlab="Age", breaks=5)
🟥 Body Fat: Typical range is ~15% to 35%, with outliers above 40% and some even above 75%. These outliers could reflect individuals with obesity or possibly data entry errors and are worth checking.
🟦 Weight (kg): Typical range is ~55 kg to 85 kg. The data have outliers on both ends: many individuals above 100 kg, which is expected in populations with higher BMI, and a few below 40 kg, who might be very lean individuals or young participants.
Consider stratifying by gender or age if these values seem inconsistent.
🟪 Height (cm): Typical range is ~160 to 180 cm. A few values fall below 140 cm, which could be females or outliers (possibly even data errors), and a couple are near 195 cm.
🟩 Diastolic (mm Hg): Typical range is ~65 to 90 mm Hg. There are outliers below 30 and above 120, which indicate potentially serious medical conditions or data entry errors. Diastolic values below 30 are quite rare physiologically.
🟥 Systolic (mm Hg): Typical range is ~100 to 145 mm Hg, with several very low values (< 50 mm Hg) and very high values (> 180 mm Hg). These may indicate hypotensive or hypertensive patients, or could be errors in data capture.
layout(matrix(1:5, nrow = 1), widths = c(1, 1, 1, 1, 1))
boxplot(data$body.fat_., main="Body Fat", col="tomato")
boxplot(data$weight_kg, main="Weight (kg)", col="skyblue")
boxplot(data$height_cm, main="Height (CM)", col="purple1")
boxplot(data$diastolic, main="Diastolic", col="cyan4")
boxplot(data$systolic, main="Systolic", col="brown2")
🟥 Grip Force : Median around 38–40 with the distribution being moderately spread (not skewed). A few outliers below (extremely low grip forces, possibly errors or true weak measurements).
🟪 Sit & Bend: Very narrow distribution, but some extreme outliers way above 200 and a few low outliers, possibly due to stiff individuals or misreporting. Possibly a skewed distribution.
🟦 Sit-ups Count: Median around 40. Fairly symmetric. A few low outliers (someone maybe stopped early or had an injury?).
🟩 Broad Jump (cm): Median near 200 with a wide distribution. Quite a few outliers fall below 100, indicating either low performance or measurement issues.
🟧 Age: Median around 30, with a slightly right-skewed distribution. No outliers observed.
layout(matrix(1:5, nrow = 1), widths = c(1, 1, 1, 1, 1))
boxplot(data$gripForce, main="Grip Force", col="brown2")
boxplot(data$sit.and.bend.forward_cm, main="Sit & Bend", col="purple1")
boxplot(data$sit.ups.counts, main="Sit-ups Count", col="skyblue")
boxplot(data$broad.jump_cm, main="Broad Jump (CM)", col="cyan4")
boxplot(data$age, main="Age", col="tomato")
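The points drawn beyond the whiskers in the box plots above are values lying more than 1.5 × IQR outside the quartiles. As a minimal sketch of how such values could be listed for checking, the hypothetical helper `flag.outliers` below applies that rule to a toy vector. It uses `quantile()`, whereas `boxplot()` works from hinges, so results can differ slightly on small samples.

```r
# Hypothetical helper: TRUE for values outside the 1.5 * IQR fences
flag.outliers <- function(x, k = 1.5) {
  q <- unname(quantile(x, c(0.25, 0.75), na.rm = TRUE))
  iqr <- q[2] - q[1]
  x < q[1] - k * iqr | x > q[2] + k * iqr
}

# Toy grip-force-like values, for illustration only
x <- c(38, 40, 41, 39, 42, 40, 37, 120)
out.idx <- which(flag.outliers(x))
x[out.idx]  # the single extreme value, 120
```

On the project data the same helper could be applied column by column, for example `data[flag.outliers(data$diastolic), ]`, to inspect the flagged rows before deciding whether they are errors.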
🟧 Females (F): As expected, median weight is lower than for males, with the majority lying roughly between 50 kg and 70 kg.
There is a wider spread of outliers on the high end (some women over 100 kg) and a few low-weight outliers below 45 kg.
🟦 Males (M): Median weight is around 70–75 kg, and the middle 50% of male weights spans ~65 kg to 90 kg. Males have more outliers overall, at both the low and high ends.
Overall, weight is right-skewed for both groups — many outliers appear on the high end of the weight scale, suggesting some heavier individuals pull the mean up. We would need to check the outliers for potential data entry errors or valid extreme cases.
library(plotly)
# Horizontal box plot of weight by gender
plot_ly(data,
x = ~weight_kg,
y = ~gender,
type = "box",
color = ~gender,
colors = c("cyan4", "red"), # customize colors here
line = list(color = 'black'),
marker = list(color = 'black', size = 4))
To predict the number of sit-ups, we build a regression tree. We use the available data to train an algorithm that predicts the number of sit-ups an athlete can perform.
This process requires multiple steps, as follows:
In this step we build the initial tree on the training data set, evaluating all candidate features for splitting after deciding on stopping rules that tell the algorithm when to stop. Here we use the following stopping rules, without pruning:
+ Minimum observations in a node to attempt a split: 20
+ Minimum observations in a terminal node: 10
+ Maximum depth of the tree: 6 levels
We arrived at these stopping rules after trying multiple combinations of values.
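# Packages for tree fitting, tree plotting, and table output
library(rpart)
library(rpart.plot)
library(pander)
# Set seed for reproducibility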
set.seed(123)
# Split data into training (70%) and test (30%) sets
train.index <- sample(1:nrow(data), size = 0.7 * nrow(data))
train.data <- data[train.index, ]
test.data <- data[-train.index, ]
# 1. Tree Induction & 2. Splitting Criteria
# Build the initial regression tree using rpart
tree.model <- rpart(sit.ups.counts ~ .,
data = train.data,
method = "anova", # For regression
control = rpart.control(
minsplit = 20, # 3. Stopping rule: min observations to split
minbucket = 10, # Min observations in terminal node
cp = 0, # Complexity parameter: 0 grows the full tree (seq(0, 0.05, 20) evaluates to 0)
maxdepth = 6 # Maximum tree depth
))
# Visualize the unpruned tree
rpart.plot(tree.model, main = "Initial Regression Tree")
The initial tree output contains a lot of information that can be used to improve the model. To improve the tree, we look at the complexity parameter cp and the related errors, and prune the initial tree accordingly.
The below complexity table shows key information we need for our tree pruning based on cross-validation:
# Examine cross-validation results
pander(tree.model$cptable)
| CP | nsplit | rel error | xerror | xstd |
|---|---|---|---|---|
| 0.3806 | 0 | 1 | 1 | 0.01406 |
| 0.09712 | 1 | 0.6194 | 0.6196 | 0.008808 |
| 0.04086 | 2 | 0.5223 | 0.5249 | 0.008041 |
| 0.02818 | 3 | 0.4814 | 0.4841 | 0.007443 |
| 0.02312 | 4 | 0.4532 | 0.4556 | 0.007362 |
| 0.02114 | 5 | 0.4301 | 0.428 | 0.007038 |
| 0.01233 | 6 | 0.409 | 0.4105 | 0.006745 |
| 0.01093 | 7 | 0.3966 | 0.4015 | 0.006746 |
| 0.009386 | 8 | 0.3857 | 0.39 | 0.0066 |
| 0.00809 | 9 | 0.3763 | 0.3789 | 0.006345 |
| 0.008 | 10 | 0.3682 | 0.376 | 0.006277 |
| 0.004837 | 11 | 0.3602 | 0.3634 | 0.006194 |
| 0.004823 | 12 | 0.3554 | 0.36 | 0.006143 |
| 0.004627 | 13 | 0.3506 | 0.3572 | 0.006126 |
| 0.004579 | 14 | 0.3459 | 0.3556 | 0.006121 |
| 0.004333 | 15 | 0.3414 | 0.3526 | 0.006112 |
| 0.00428 | 16 | 0.337 | 0.3484 | 0.006062 |
| 0.004116 | 17 | 0.3327 | 0.344 | 0.00603 |
| 0.003791 | 18 | 0.3286 | 0.3383 | 0.005948 |
| 0.003518 | 19 | 0.3248 | 0.334 | 0.005891 |
| 0.003374 | 20 | 0.3213 | 0.3302 | 0.005706 |
| 0.002775 | 21 | 0.3179 | 0.3253 | 0.005537 |
| 0.002464 | 22 | 0.3152 | 0.3232 | 0.005549 |
| 0.002348 | 23 | 0.3127 | 0.3227 | 0.005548 |
| 0.002268 | 24 | 0.3104 | 0.3218 | 0.005534 |
| 0.002094 | 25 | 0.3081 | 0.3197 | 0.00547 |
| 0.00203 | 26 | 0.306 | 0.3175 | 0.005438 |
| 0.001464 | 27 | 0.304 | 0.3122 | 0.005348 |
| 0.001399 | 28 | 0.3025 | 0.3118 | 0.005359 |
| 0.00137 | 29 | 0.3011 | 0.312 | 0.005379 |
| 0.001365 | 30 | 0.2997 | 0.3119 | 0.005379 |
| 0.001276 | 31 | 0.2984 | 0.3117 | 0.005381 |
| 0.001141 | 32 | 0.2971 | 0.3098 | 0.005385 |
| 0.001099 | 33 | 0.296 | 0.3081 | 0.005377 |
| 0.00108 | 34 | 0.2949 | 0.3074 | 0.005372 |
| 0.001028 | 35 | 0.2938 | 0.3073 | 0.005376 |
| 0.0009767 | 36 | 0.2927 | 0.3064 | 0.005371 |
| 0.0008478 | 37 | 0.2918 | 0.3057 | 0.005346 |
| 0.000781 | 38 | 0.2909 | 0.3068 | 0.005418 |
| 0.0007593 | 39 | 0.2901 | 0.3066 | 0.005412 |
| 0.0007381 | 40 | 0.2894 | 0.3066 | 0.005411 |
| 0.0006773 | 41 | 0.2886 | 0.3067 | 0.005416 |
| 0.0006724 | 42 | 0.288 | 0.3071 | 0.005426 |
| 0.0006642 | 43 | 0.2873 | 0.3074 | 0.005427 |
| 0.0006458 | 44 | 0.2866 | 0.3076 | 0.005431 |
| 0.0006012 | 45 | 0.286 | 0.3074 | 0.005436 |
| 0.0005739 | 46 | 0.2854 | 0.3078 | 0.00549 |
| 0.0005654 | 47 | 0.2848 | 0.3081 | 0.005491 |
| 0.0005406 | 49 | 0.2837 | 0.3081 | 0.005494 |
| 0.0005262 | 50 | 0.2831 | 0.308 | 0.005486 |
| 0.0005177 | 51 | 0.2826 | 0.3083 | 0.005496 |
| 0.0005174 | 52 | 0.2821 | 0.3079 | 0.005489 |
| 0.000505 | 53 | 0.2816 | 0.3078 | 0.005487 |
| 0.000464 | 54 | 0.2811 | 0.3072 | 0.005483 |
| 0.0003884 | 55 | 0.2806 | 0.3067 | 0.005482 |
| 0.0003803 | 56 | 0.2802 | 0.3065 | 0.005481 |
| 0.0003274 | 57 | 0.2798 | 0.3064 | 0.005463 |
| 0.0003204 | 58 | 0.2795 | 0.3068 | 0.00547 |
| 0.0002966 | 59 | 0.2792 | 0.307 | 0.005471 |
| 0.0002495 | 60 | 0.2789 | 0.3072 | 0.005477 |
| 0 | 61 | 0.2786 | 0.3071 | 0.005497 |
From the table above, the minimum cross-validation error (xerror) is 0.3057, reached at 37 splits with cp = 0.0008478.
Under the 1-SE rule, we identify the largest cp (the smallest tree) whose xerror is within one standard error (xstd) of that minimum, to balance simplicity and accuracy. Here the threshold is 0.3057 + 0.005346 ≈ 0.3110, and the first row in the table that meets it has 32 splits, with cp = 0.001141 and xerror = 0.3098.
plotcp(tree.model)
As mentioned earlier, we select the largest cp whose xerror is within one standard error of the minimum (to balance simplicity and accuracy). Based on the table above, the selected tree has 32 splits and cp = 0.001141.
cp.table <- tree.model$cptable
## Identify the minimum `xerror` and its `cp`.
min.xerror <- min(cp.table[, "xerror"])
min.cp.row <- which.min(cp.table[, "xerror"])
min.cp <- cp.table[min.cp.row, "CP"]
## Get the standard error (`xstd`) of the minimum `xerror`
xerror.std <- cp.table[min.cp.row, "xstd"]
threshold <- min.xerror + xerror.std # Upper bound (1 SE rule)
## Find the simplest tree (`cp`) Where `xerror less than or equal to Threshold`.
best.cp.row <- which(cp.table[, "xerror"] <= threshold)[1] # First row meeting criteria
best.cp <- cp.table[best.cp.row, "CP"]
## Two different trees: best CP vs minimum CP
pruned.tree.best.cp <- prune(tree.model, cp = best.cp)
pruned.tree.min.cp <- prune(tree.model, cp = min.cp)
The tree pruned with the 1-SE rule, plotted below, uses the best cp of about 0.0011, which corresponds to 32 splits and a cross-validation error of 0.3098 in the table above.
# Visualize the pruned tree: best CP
rpart.plot(pruned.tree.best.cp, main = paste("Pruned Tree (Best CP): cp = ", round(best.cp,4)))
The minimum-cp tree, plotted below, has 37 splits according to the table above; it is more complex than the best-cp tree but gains little in cross-validation error.
# Visualize the pruned tree: minimum CP
rpart.plot(pruned.tree.min.cp, main = paste("Pruned Tree (Minimum CP): cp = ", round(min.cp,4)))
Next, we use the pruned regression trees to make predictions. Only five features were used by the tree algorithm: “body.fat_.”, “gender”, “broad.jump_cm”, “class” and “age”.
As the next step, we use the trees pruned with the best and the minimum cp to make predictions. Since the trees did not use all variables, we also fit two linear regression models and compare the performance of all models.
The two linear regression models were built as follows: the first (LSE01) uses only the five variables selected by the tree, and the second (LSE02) starts from all variables and applies step-wise selection based on AIC.
# 5. Prediction
# Make predictions on test data
pred.best.cp <- predict(pruned.tree.best.cp, newdata = test.data)
pred.min.cp <- predict(pruned.tree.min.cp, newdata = test.data)
# Evaluate model performance: best.cp
mse.tree.best.cp <- mean((test.data$sit.ups.counts - pred.best.cp)^2)
rmse.tree.best.cp <- sqrt(mse.tree.best.cp)
r.squared.tree.best.cp <- cor(test.data$sit.ups.counts, pred.best.cp)^2
# min.cp
mse.tree.min.cp <- mean((test.data$sit.ups.counts - pred.min.cp)^2)
rmse.tree.min.cp <- sqrt(mse.tree.min.cp)
r.squared.tree.min.cp <- cor(test.data$sit.ups.counts, pred.min.cp)^2
##
# fit ordinary least square regression
LSE01 <- lm(sit.ups.counts ~ body.fat_. + gender + broad.jump_cm + class + age, data = train.data)
pred.lse01 <- predict(LSE01, newdata = test.data)
mse.lse01 <- mean((test.data$sit.ups.counts - pred.lse01)^2)
rmse.lse01 <- sqrt(mse.lse01)
r.squared.lse01 <- cor(test.data$sit.ups.counts, pred.lse01)^2
##
## ordinary LSE regression model with step-wise variable selection
library(MASS) # provides stepAIC()
lse02.fit <- lm(sit.ups.counts~.,data = train.data)
AIC.fit <- stepAIC(lse02.fit, direction="both", trace = FALSE)
pred.lse02 <- predict(AIC.fit, test.data)
mse.lse02 <- mean((test.data$sit.ups.counts - pred.lse02)^2) # mean square error
rmse.lse02 <- sqrt(mse.lse02) # root mean square error
r.squared.lse02 <- (cor(test.data$sit.ups.counts, pred.lse02))^2 # r-squared
###
Errors <- cbind(MSE = c(mse.tree.best.cp, mse.tree.min.cp, mse.lse01, mse.lse02),
RMSE = c(rmse.tree.best.cp, rmse.tree.min.cp, rmse.lse01, rmse.lse02),
r.squared = c(r.squared.tree.best.cp, r.squared.tree.min.cp, r.squared.lse01, r.squared.lse02))
rownames(Errors) = c("tree.best.cp", "tree.min.cp", "lse01", "lse02")
pander(Errors)
| | MSE | RMSE | r.squared |
|---|---|---|---|
| tree.best.cp | 61.85 | 7.864 | 0.7023 |
| tree.min.cp | 61.21 | 7.823 | 0.7054 |
| lse01 | 54.52 | 7.384 | 0.7376 |
| lse02 | 53.14 | 7.29 | 0.7442 |
Contrary to expectation, the linear regression models outperform the trees: the step-wise model “lse02” has the lowest test error, ahead of both the “tree.min.cp” and “tree.best.cp” trees.
| Model | MSE ↓ | R² ↑ | Interpretation |
|---|---|---|---|
| lse02 | 53.14 | 0.744 | Lower error + captures ~74% of variance |
| tree.best.cp | 61.85 | 0.702 | Simpler, but not as accurate |
Hence, even with cp tuning, trees have a ceiling in how well they can model smooth, continuous relationships. Regression models, particularly when well specified, often do better on that kind of problem.
However, comparing both outcomes, the tree remains usable if needed, as it still provides strong separation with much less complexity.
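The three metrics compared above are computed the same way for every model, so they can be wrapped in one function. The sketch below is illustrative only: `eval.metrics` is a hypothetical helper, and the toy vectors are not the project's predictions. R² is taken as the squared correlation between observed and predicted values, matching the calculation used in the report.

```r
# Hypothetical helper: test-set metrics for a vector of predictions
eval.metrics <- function(actual, predicted) {
  mse <- mean((actual - predicted)^2)
  c(MSE = mse,
    RMSE = sqrt(mse),
    r.squared = cor(actual, predicted)^2)  # squared correlation
}

# Toy values for illustration only
actual    <- c(30, 42, 51, 38, 45)
predicted <- c(32, 40, 49, 41, 44)
m <- eval.metrics(actual, predicted)
m[["MSE"]]  # 4.4: the squared errors 4, 4, 4, 9, 1 average to 22 / 5
```

With the project objects, a call such as `eval.metrics(test.data$sit.ups.counts, pred.lse02)` would return one row of the comparison table.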
Variable importance in regression trees identifies which predictors have the strongest influence on the predictions of the target variable.
Comparing variable importance between the minimum-cp and best-cp trees shows that both have the same distribution, as outlined in the two graphs below.
importance <- pruned.tree.best.cp$variable.importance
barplot(sort(importance, decreasing = TRUE),
main = "Variable Importance: Best CP",
col = "skyblue",
las = 2)
importance <- pruned.tree.min.cp$variable.importance
barplot(sort(importance, decreasing = TRUE),
main = "Variable Importance: Minimum CP",
col = "skyblue",
las = 2)
Looking at the trees above, we can see that some variables have high importance but do not appear in the tree. It is not uncommon for variables to appear in the variable importance list without being shown in the final regression or classification tree.
Importance is measured by performance gain (e.g., reduction in MSE or Gini impurity). This captures how much a variable improves model accuracy but does not imply statistical significance.
Larger importance values indicate stronger associations with the outcome, though correlation does not imply causation. For example, height and grip force are closely linked to gender, since males tend to be taller and have a stronger grip; these variables are confounded with gender, and because the tree already splits on gender they become redundant in the tree build.
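One reason a variable can rank high in importance without appearing in the plot is that rpart also credits variables used as surrogate splits. The sketch below illustrates this on the built-in `mtcars` data rather than the project data. It fits a small regression tree, rescales the importance scores to percentages, and lists the variables that carry importance but never serve as a primary split.

```r
library(rpart)

# Small regression tree on a built-in data set (illustration only)
fit <- rpart(mpg ~ ., data = mtcars, method = "anova")

# Raw importance scores, then rescaled so they sum to about 100
imp <- fit$variable.importance
imp.pct <- round(100 * imp / sum(imp), 1)

# Variables that actually appear as primary splits in the tree
split.vars <- setdiff(unique(as.character(fit$frame$var)), "<leaf>")

# Variables with importance that never appear as a primary split
hidden.vars <- setdiff(names(imp), split.vars)
```

The same three steps applied to `pruned.tree.best.cp` would show which of the report's predictors contribute only through surrogate splits.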
For comparison, we also print the inferential table of the step-wise linear regression model below.
pander(summary(AIC.fit)$coef)
| | Estimate | Std. Error | t value | Pr(>\|t\|) |
|---|---|---|---|---|
| (Intercept) | 57.88 | 3.041 | 19.03 | 2.893e-79 |
| age | -0.397 | 0.007227 | -54.93 | 0 |
| genderM | 5.575 | 0.3692 | 15.1 | 6.609e-51 |
| height_cm | -0.1332 | 0.01883 | -7.074 | 1.61e-12 |
| weight_kg | 0.1463 | 0.01469 | 9.96 | 2.962e-23 |
| body.fat_. | -0.2971 | 0.02023 | -14.68 | 2.798e-48 |
| diastolic | 0.02365 | 0.007704 | 3.07 | 0.002146 |
| gripForce | 0.04457 | 0.01554 | 2.868 | 0.004141 |
| sit.and.bend.forward_cm | 0.01985 | 0.01243 | 1.598 | 0.1102 |
| broad.jump_cm | 0.07872 | 0.00407 | 19.34 | 9.629e-82 |
| classB | -4.209 | 0.2255 | -18.66 | 2.422e-76 |
| classC | -7.467 | 0.2436 | -30.66 | 8.64e-197 |
| classD | -13.68 | 0.31 | -44.14 | 0 |
We would still recommend using the tree, even though the linear regression model performs better, because we did not remove variables such as “broad.jump_cm” that are strongly correlated with the response variable but not causally related to it.
To predict the class variable, we use classification trees. Classification trees are a type of supervised learning algorithm that recursively partitions the feature space to predict categorical target variables.
The initial tree size is controlled by the hyper-parameters set in “rpart.control()”. The initial tree tends to be over-fitted.
A common challenge in classification is that the classes (categories) are not equally represented. In our data set, however, the observations are evenly distributed across the four classes. Since the “class” variable has four levels, we combine them to predict the probability of an athlete being classified as “A” or “Not A”.
As we build the tree, we start with the entire data set at the root node and then recursively split the data into purer subsets.
Optimal tree size: typically where xerror is minimized.
# Step 1: Recode the target variable into binary
data$binary_class <- ifelse(data$class == "A", 1, 0)
data$binary_class <- as.factor(data$binary_class)
# Step 2: Split data into training and test sets
library(caret) # provides createDataPartition()
set.seed(123)
train.index <- createDataPartition(data$binary_class, p = 0.7, list = FALSE)
train.data <- data[train.index, ]
test.data <- data[-train.index, ]
# Step 3: Fit the classification tree using the new binary target
tree.model <- rpart(binary_class ~ .,
data = train.data[, !(names(train.data) %in% "class")], # drop original class
method = "class", # classification tree
parms = list(split = "gini", # Using Gini index
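# loss[i, j] = cost of predicting class j when the true class is i,
# so as written: FP (true 0, predicted 1) = 1, FN (true 1, predicted 0) = 0.5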
loss = matrix(c(0, 0.5, 1, 0), nrow = 2)
),
control = rpart.control(minsplit = 20, # Min 20 obs to split
minbucket = 10, # Min 10 obs in leaf
cp = 0.001, # Complexity parameter
maxdepth = 7)) # Max tree depth
rpart.plot(tree.model,
extra = 104, # check the help document for more information
# color palette is a sequential color scheme that blends green (G) to blue (Bu)
box.palette = "GnBu",
branch.lty = 1,
shadow.col = "gray",
nn = TRUE)
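In rpart, as far as its documentation describes, the loss matrix is indexed with the true class in rows and the predicted class in columns. The sketch below rebuilds the same matrix used in the fit above with labelled dimensions to show which off-diagonal cell is which. The `dimnames` are added here only for readability and are not part of the original call.

```r
# Same values as the loss matrix above; matrix() fills column by column
loss <- matrix(c(0, 0.5, 1, 0), nrow = 2,
               dimnames = list(truth = c("0", "1"),
                               predicted = c("0", "1")))

fp.cost <- loss["0", "1"]  # predicting 1 ("A") when the truth is 0: false positive
fn.cost <- loss["1", "0"]  # predicting 0 when the truth is 1 ("A"): false negative
```

Read this way, the values charge 1 for a false positive and 0.5 for a false negative; swapping the two off-diagonal values would make a missed “A” athlete the costlier error.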
#Print the complexity parameter table
pander(tree.model$cptable)
| CP | nsplit | rel error | xerror | xstd |
|---|---|---|---|---|
| 0.02005 | 0 | 1 | 2 | 0.03578 |
| 0.01152 | 7 | 0.8332 | 1.384 | 0.03044 |
| 0.009954 | 8 | 0.8217 | 1.32 | 0.02978 |
| 0.007821 | 12 | 0.779 | 1.297 | 0.02959 |
| 0.004266 | 15 | 0.7555 | 1.27 | 0.02934 |
| 0.004124 | 16 | 0.7513 | 1.253 | 0.02918 |
| 0.002474 | 19 | 0.7389 | 1.245 | 0.02909 |
| 0.002453 | 29 | 0.7142 | 1.224 | 0.02881 |
| 0.001706 | 33 | 0.7044 | 1.202 | 0.0285 |
| 0.00128 | 34 | 0.7026 | 1.183 | 0.02827 |
| 0.001 | 35 | 0.7014 | 1.179 | 0.02815 |
The graph below shows the reference line (dashed line) for the 1-SE rule. The numbers along the top of the plot give the number of leaf nodes in the corresponding tree.
# Print cp table
#printcp(tree.model)
# Plot cp vs cross-validated error
plotcp(tree.model)
For clarity in the analysis, we introduce two notations: min.cp represents the cp value yielding the minimum cross-validation error, while 1SE.cp denotes the cp value selected by the more conservative 1-SE rule (minimal error plus one standard error).
The pruned tree diagram below is based on the 1-SE rule.
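# Find the cp value that minimizes cross-validated error
cp.table <- tree.model$cptable
min.row <- which.min(cp.table[, "xerror"])
min.cp <- cp.table[min.row, "CP"]
# 1-SE rule: largest cp whose xerror is within one xstd of the minimum
threshold <- cp.table[min.row, "xerror"] + cp.table[min.row, "xstd"]
cp.1SE <- cp.table[which(cp.table[, "xerror"] <= threshold)[1], "CP"]
# Prune the tree with each cp
pruned.tree.1SE <- prune(tree.model, cp = cp.1SE)
pruned.tree.min <- prune(tree.model, cp = min.cp)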
# Visualize the pruned tree
rpart.plot(pruned.tree.1SE,
extra = 104, # check the help document for more information
# color palette is a sequential color scheme that blends green (G) to blue (Bu)
box.palette = "GnBu",
branch.lty = 1,
shadow.col = "gray",
nn = TRUE,
main = "Pruned Classification Tree (1-SE Rule)")
The pruned tree diagram above is based on the 1-SE rule. Next, we plot the tree diagram based on the minimum cross-validation error.
# Visualize the pruned tree
rpart.plot(pruned.tree.min,
extra = 104, # check the help document for more information
# color palette is a sequential color scheme that blends green (G) to blue (Bu)
box.palette = "GnBu",
branch.lty = 1,
shadow.col = "gray",
nn = TRUE,
main = "Pruned Classification Tree (Min Cross Validation)")
Classification trees make predictions by routing observations through a series of hierarchical splits, starting at the root node and ending at a terminal leaf node. Each split applies a decision rule based on feature values.
# Make predictions on the test set
pred.label.1SE <- predict(pruned.tree.1SE, test.data, type = "class") # default cutoff 0.5
pred.prob.1SE <- predict(pruned.tree.1SE, test.data, type = "prob")[,2]
##
pred.label.min <- predict(pruned.tree.min, test.data, type = "class") # default cutoff 0.5
pred.prob.min <- predict(pruned.tree.min, test.data, type = "prob")[,2]
# Confusion matrix (requires caret)
#conf.matrix <- confusionMatrix(pred.label.min, test.data$binary_class, positive = "1")
#print(conf.matrix)
train.data$binary_class <- data$binary_class[train.index]
test.data$binary_class <- data$binary_class[-train.index]
# Remove class column if it still exists
train.data$class <- NULL
test.data$class <- NULL
########################
### logistic regression
logit.fit <- glm(binary_class ~ ., data = train.data, family = binomial)
AIC.logit <- step(logit.fit, direction = "both", trace = 0)
pred.logit <- predict(AIC.logit, test.data, type = "response")
# ROC curve and AUC
library(pROC) # provides roc()
roc.tree.1SE <- roc(test.data$binary_class, pred.prob.1SE)
roc.tree.min <- roc(test.data$binary_class, pred.prob.min)
roc.logit <- roc(test.data$binary_class, pred.logit)
##
### Sen-Spe
tree.1SE.sen <- roc.tree.1SE$sensitivities
tree.1SE.spe <- roc.tree.1SE$specificities
#
tree.min.sen <- roc.tree.min$sensitivities
tree.min.spe <- roc.tree.min$specificities
#
logit.sen <- roc.logit$sensitivities
logit.spe <- roc.logit$specificities
## AUC
auc.tree.1SE <- roc.tree.1SE$auc
auc.tree.min <- roc.tree.min$auc
auc.logit <- roc.logit$auc
###
plot(1-logit.spe, logit.sen,
xlab = "1 - specificity",
ylab = "sensitivity",
col = "darkred",
type = "l",
lty = 1,
lwd = 1,
main = "ROC: CART and Logistic Regression")
lines(1-tree.1SE.spe, tree.1SE.sen,
col = "blue",
lty = 1,
lwd = 1)
lines(1-tree.min.spe, tree.min.sen,
col = "orange",
lty = 1,
lwd = 1)
abline(0,1, col = "skyblue3", lty = 2, lwd = 2)
legend("bottomright", c("Logistic", "Tree 1SE", "Tree Min"),
lty = c(1,1,1), lwd = rep(1,3),
col = c("darkred", "blue", "orange"),
bty="n",cex = 0.8)
## annotation - AUC
text(0.8, 0.46, paste("Logistic AUC: ", round(auc.logit,4)), cex = 0.8)
text(0.8, 0.4, paste("Tree 1SE AUC: ", round(auc.tree.1SE,4)), cex = 0.8)
text(0.8, 0.34, paste("Tree Min AUC: ", round(auc.tree.min,4)), cex = 0.8)
The ROC curves and corresponding AUC values show that the logistic regression model performs marginally better than both pruned tree models, while the more complex tree (pruned using the minimum cross-validation error) has essentially the same predictive ability as the simpler tree pruned according to the 1-SE rule.
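To make the AUC values easier to interpret: AUC equals the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below computes it directly from ranks (the Mann-Whitney form) on toy data. `auc.rank` is a hypothetical helper written for illustration, not part of pROC.

```r
# AUC from ranks: share of positive/negative pairs ordered correctly
# (ties receive average ranks)
auc.rank <- function(labels, scores) {
  r <- rank(scores)
  n.pos <- sum(labels == 1)
  n.neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n.pos * (n.pos + 1) / 2) / (n.pos * n.neg)
}

# Toy labels and scores, for illustration only
labels <- c(0, 0, 1, 1)
scores <- c(0.1, 0.4, 0.35, 0.8)
auc.toy <- auc.rank(labels, scores)  # 0.75: three of the four pairs are ordered correctly
```

An AUC of 0.5 corresponds to the dashed diagonal in the ROC plot, and 1 to a perfect ranking.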
In binary classification, predicted probabilities must be converted into class labels (e.g., 0 or 1) by applying a cut-off threshold. The choice of this threshold significantly impacts model performance, as it balances accuracy, sensitivity (recall), and specificity.
Key approaches to determine the optimal cut-off:
I. Trade-off Between Sensitivity and Specificity
II. Accuracy-Driven Cut-off
III. Cost-Sensitive Threshold
IV. ROC and Precision-Recall Curves
V. Practical Considerations
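As an illustration of approach I, the sketch below picks the cut-off that maximizes Youden's J (sensitivity + specificity - 1) over a grid of candidates. The helper `youden.cutoff`, the grid, and the toy data are assumptions made for this example; the report itself uses the cost-sensitive approach (III).

```r
# Hypothetical helper: cut-off maximizing Youden's J over a grid
youden.cutoff <- function(labels, probs, cutoffs = seq(0.05, 0.95, by = 0.05)) {
  j <- sapply(cutoffs, function(cut) {
    pred <- as.integer(probs > cut)
    sens <- sum(pred == 1 & labels == 1) / sum(labels == 1)
    spec <- sum(pred == 0 & labels == 0) / sum(labels == 0)
    sens + spec - 1
  })
  cutoffs[which.max(j)]  # first cut-off reaching the maximum J
}

# Toy labels and predicted probabilities, for illustration only
labels <- c(0, 0, 0, 1, 0, 1, 1, 1, 1)
probs  <- c(0.12, 0.22, 0.31, 0.37, 0.43, 0.48, 0.62, 0.71, 0.88)
best <- youden.cutoff(labels, probs)  # 0.45: sensitivity 4/5, specificity 4/4
```

Unlike a cost-based threshold, this criterion weights false positives and false negatives equally.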
# Predictive probabilities of the pruned tree
pred.prob.min <- predict(pruned.tree.min, train.data, type = "prob")[, 2]
# Cutoff values
cutoff <- seq(0, 1, length = 10)
cost <- numeric(length(cutoff))
# Misclassification cost for each cutoff
for (i in seq_along(cutoff)) {
pred.label <- ifelse(pred.prob.min > cutoff[i], 1, 0)
FN <- sum(pred.label == 0 & train.data$binary_class == 1)
FP <- sum(pred.label == 1 & train.data$binary_class == 0)
cost[i] <- 5 * FP + 20 * FN
}
# Optimal cutoff
min.ID <- which(cost == min(cost))
optim.prob <- mean(cutoff[min.ID])
# Plot
plot(cutoff, cost, type = "b", col = "navy",
main = "Cutoff vs Misclassification Cost",
xlab = "Cutoff", ylab = "Cost")
text(optim.prob, min(cost) + 20000,
paste("Optimal cutoff:", round(optim.prob, 3)),
cex = 0.8, col = "darkred")
The resulting optimal cut-off probability of 0.222, displayed on the plot above, will be used to make predictions on the test dataset, and the corresponding accuracy will be reported. We emphasize once again that this optimal threshold is chosen specifically to minimize the total cost of misclassification.
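A minimal sketch of that last step is shown below, assuming a hypothetical helper `apply.cutoff` and toy vectors rather than the project's test set. It converts probabilities to labels at a given cut-off and reports accuracy together with sensitivity and specificity.

```r
# Hypothetical helper: apply a probability cut-off and summarise the result
apply.cutoff <- function(labels, probs, cutoff) {
  pred <- as.integer(probs > cutoff)
  c(accuracy    = mean(pred == labels),
    sensitivity = sum(pred == 1 & labels == 1) / sum(labels == 1),
    specificity = sum(pred == 0 & labels == 0) / sum(labels == 0))
}

# Toy labels and probabilities, not the project's test set
labels <- c(1, 0, 0, 1, 0, 1, 0, 0)
probs  <- c(0.80, 0.10, 0.30, 0.25, 0.05, 0.60, 0.20, 0.50)
res <- apply.cutoff(labels, probs, cutoff = 0.222)
res[["accuracy"]]  # 0.75: six of the eight toy cases are classified correctly
```

With the project objects, the inputs would be `as.integer(as.character(test.data$binary_class))`, `predict(pruned.tree.min, test.data, type = "prob")[, 2]`, and `optim.prob`.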